import json
import os
import sys
import time

from utils.datasetUtils import splitTrainDataByNums
from utils.modelUtils import getAvailableDevice, getResNet18
from utils.randomUtils import all_combinations

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))

from _dynamic_dataset.AccCombinationRecord import AccCombinationRecord, WriteRecordsToFile
from _dynamic_dataset.TrainDataset import TrainSubset
from rootPath import project_path
from server_client.Client import Client
from server_client.Server import Server
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from datetime import datetime

# Experiment goal: verify the computation of the ground-truth Shapley values
# when every party's dataset has the same size and the same distribution.

# Current date and time; only referenced by the commented-out log-directory
# setup in this script.
now = datetime.now()
# Hyperparameters
batch_size = 128
learning_rate = 0.0001
num_epochs = 40
# threshold = args.threshold
# maxNum = args.maxNum

# Read the client and server configuration files.
# (Original note: "gradient-based Shapley value computation".)
# The encoding is pinned to UTF-8 so the JSON decodes identically regardless of
# the platform's locale encoding (e.g. cp936 / cp1252 on Windows).
with open(os.path.join(project_path, "conf", "client.json"), "r", encoding="utf-8") as f:
    clientConf = json.load(f)

with open(os.path.join(project_path, "conf", "server.json"), "r", encoding="utf-8") as f:
    serverConf = json.load(f)

# Logger configuration (currently disabled):
# dirPath = project_path+"/experiments/experiments3/mr/sdss/logs"+f"/{now.year}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}_{now.second:02d}"
# Create the directory if it does not already exist:
# os.makedirs(dirPath, exist_ok=True)
# logging.basicConfig(filename=f'{dirPath}/main.log', level=logging.INFO)

# Per-channel normalization statistics shared by the training and test pipelines.
# NOTE(review): these std values are the widely copied CIFAR-10 constants; the
# exact per-channel std of the training split is roughly (0.247, 0.243, 0.261).
# Confirm which is intended before changing, since it alters results.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training preprocessing: random 32x32 crop with 4-pixel padding and a random
# horizontal flip for augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Test-set preprocessing: no augmentation, only tensor conversion plus the same
# normalization used for training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Per-client data allocation. Each dict is handed to splitTrainDataByNums;
# presumably it maps a class label (0-9) to the number of training samples that
# client receives from that class. All five clients get identical allocations,
# i.e. the same dataset size and the same label distribution, which is the
# premise of this experiment.
num_classes = 10
samples_per_class = 1000
numsList = [
    {label: samples_per_class for label in range(num_classes)}
    for _ in range(5)  # five clients
]
clientNum = len(numsList)
clients = []

device = getAvailableDevice()
# Server-side evaluation set: the CIFAR-10 test split.
test_set = torchvision.datasets.CIFAR10(root=os.path.join(project_path, 'data'), train=False, download=True,
                                        transform=test_transform)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
# Training set (CIFAR-10 train split), partitioned among the clients below.
train_dataset = torchvision.datasets.CIFAR10(root=os.path.join(project_path, 'data'), train=True, download=True,
                                             transform=train_transform)
# NOTE(review): this single model object is passed to every Client and to the
# Server. Confirm those classes copy it rather than mutating a shared instance.
model = getResNet18()
datasetIndices = splitTrainDataByNums(orderList=numsList, trainDataset=train_dataset)
for i in range(clientNum):
    trainSubDataset = TrainSubset(train_dataset, datasetIndices[i])
    # Use the shared batch_size hyperparameter rather than a duplicated literal.
    client_loader = DataLoader(trainSubDataset, shuffle=True, batch_size=batch_size, num_workers=2)
    clients.append(Client(clientConf, model, device, client_loader, i))
    print(f"客户端{i}创建成功")
print("客户端全部初始化成功！")

# Build every combination of client ids that will be evaluated.
clientIds = list(range(clientNum))
combinations = all_combinations(clientIds)
# Report completion only after the combinations actually exist (this message
# used to be printed before any combination had been generated).
print("生成id组合完成")

acc_record = AccCombinationRecord()
# Server; it is given test_loader, which the training loop later reads back
# (presumably) as server.eval_loader.
server = Server(serverConf, test_loader, device, model)

def _evaluate_accuracy(net, loader, dev):
    """Return the top-1 accuracy of *net* over *loader*, in percent.

    Puts *net* into eval mode and, under ``torch.no_grad()``, runs it on every
    ``(images, labels)`` batch yielded by *loader* after moving both tensors to
    *dev*. The predicted class is the arg-max over dimension 1 of the output.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(dev), labels.to(dev)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


# Wall-clock start of the training / evaluation loop.
start_time = time.time()
for epoch in range(1, num_epochs + 1):
    # logging.info(f"epoch={epoch}")
    # Each client trains locally, given the server's current global model.
    for client in clients:
        client.train(server.global_model)
    # For every client combination, aggregate only its members into
    # server.sub_model and record that model's real test accuracy.
    for combination in combinations:
        com_clients = [clients[clientId] for clientId in combination]
        server.model_aggregate(com_clients, server.sub_model)
        acc = _evaluate_accuracy(server.sub_model, server.eval_loader, device)
        acc_record.addCombinationAcc(epoch, combination, acc)
        # Include the combination so the per-epoch lines can be told apart.
        print(f"epoch={epoch},combination={combination},acc={acc}")
    # Update the global model by aggregating over all clients.
    server.model_aggregate(clients, server.global_model)
# Wall-clock end; report the total elapsed time.
end_time = time.time()
print(f"总耗时为{end_time - start_time}s")
# logging.info(f"总耗时为{end_time-start_time}s")
# NOTE(review): acc_record is never persisted because the call below is
# commented out (together with the dirPath setup). Re-enable it, or the
# per-combination accuracies are lost when the script exits.
# WriteRecordsToFile(dirPath+"/acc_records.json", acc_record)
